# Copyright (c) 2021 Red Hat, Inc.
# Ported to Python 3, fixed linting errors, switched from class to
# module/function interface, refined naming and documentation.
#
# Copyright (c) 2012 George Prekas <prekgeo@yahoo.com>
# Taken from https://github.com/prekageo/optistate@2c8bbb0
#
# Based on code from Robert Dick <dickrp@eecs.umich.edu> and Pat Maupin
# <pmaupin@gmail.com>. Most of the original code was re-written for
# performance reasons.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to
# deal in the Software without restriction, including without limitation the
# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
# sell copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
# IN THE SOFTWARE.
"""
Boolean function operations.

DEFINITIONS

minterm
    A product (conjunction) of function variables where each variable
    appears at most once.

    See https://en.wikipedia.org/wiki/Canonical_normal_form#Minterm

    A minterm is represented by a tuple of two bitmaps. The first specifies
    which input variables appear in the product as is (non-negated). The
    second specifies which input variables don't appear in the product at
    all, neither as is nor negated. Variables whose bits are clear in both
    bitmaps appear in the product negated.

    That is, if "positive" is the bitmap of positive terms, and "negative" is
    the bitmap of negative terms, then each tuple contains (with the
    complement limited to the bits of the function's variables):

        (positive, ~(positive | negative))

    Conversely, if you have a (first, second) tuple, then:

        positive = first
        negative = ~(first | second)

    E.g., for a function of two variables with A as bit 0 and B as bit 1,
    A&B would be expressed as [(3, 0)], A|B as [(1, 2), (2, 1)], !A as
    [(0, 2)], !A&!B as [(0, 0)], and !A|!B as [(0, 2), (0, 1)].

minterm index
    Essentially, a bitmap of particular input variable values for a function,
    identifying that set of inputs uniquely.

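    Say, a function has three inputs: A, B, and C. If we map them to bits in
    the order they are written, with A as the most significant bit, a bitmap
    for A=1, B=0, and C=1 would be 0b101. Similarly, A=0, B=1, C=1 would then
    be represented as 0b011. Note that minexpr_format() names bit j with
    variables[j], starting from the least significant bit, so this ordering
    corresponds to the variable list ['C', 'B', 'A'].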

minexpr
    Either a boolean representing a constant value,
    or a list of sum-of-product minterms representing a boolean function.
"""
import math


def minimize(numvars, ones, dont_cares):
    """
    Transform a boolean function into a minimal sum-of-products solution.

    Transform a boolean function into a minimal sum-of-products solution,
    using the Quine-McCluskey algorithm and Petrick's method.

    Args:
        numvars:    Number of function variables.
        ones:       A list of all minterm indices for which the function
                    evaluates to 1 (true).
                    E.g. if the function in question would only evaluate to 1
                    (true) for the example cases listed under "minterm index",
                    then the minterm index list would be [0b101, 0b011], or
                    [5, 3].
        dont_cares: A list of all minterm indices for which the function value
                    doesn't matter.

        All indices are expected to lie in the range [0, 2**numvars).
        Duplicate indices, and indices present in both lists, are tolerated.

    Returns:
        A number representing the complexity of the result (smaller ==
        simpler), and the resulting "minexpr" representing the minimized
        function/constant.
    """
    # Handle special case for functions that always evaluate to True or False.
    if len(ones) == 0:
        return 0, False
    # Count *distinct* indices: summing the list lengths would let duplicates
    # (or indices listed both as ones and as don't-cares) fake or hide full
    # coverage of the input space.
    if len(set(ones) | set(dont_cares)) == 1 << numvars:
        return 0, True
    primes = find_primes(numvars, ones + dont_cares)
    return unate_cover(numvars, list(primes), ones)


def find_primes(numvars, cubes):
    """
    Find all prime implicants of a boolean function.

    This is the merging phase of the Quine-McCluskey algorithm. Implicants
    are grouped by the number of positive bits in their first bitmap. Every
    pair taken from adjacent groups is passed to minterm_merge(). Merged
    implicants form the next round's groups. Any implicant that could not be
    merged with another one is prime.

    Args:
        numvars:    Number of function variables.
        cubes:      A list of indices for the minterms for which the function
                    evaluates to 1 or don't-care.

    Returns:
        A set of prime minterms.
    """
    # Start with one full product (no variable left out) per index, grouped
    # by how many variables appear non-negated.
    groups = [set() for _ in range(numvars + 1)]
    for index in cubes:
        groups[_bitcount(index)].add((index, 0))

    primes = set()
    while groups:
        merged_groups = []
        merged_away = set()
        for lower, upper in zip(groups[:-1], groups[1:]):
            merged = set()
            for low in lower:
                for high in upper:
                    combined = minterm_merge(low, high)
                    if combined is not None:
                        merged.add(combined)
                        merged_away.update((low, high))
            merged_groups.append(merged)
        # Whatever survived this round without merging is a prime implicant.
        round_implicants = {imp for group in groups for imp in group}
        primes |= round_implicants - merged_away
        # One fewer group each round, so the loop terminates.
        groups = merged_groups
    return primes


def unate_cover(numvars, primes, ones):
    """
    Use prime implicants to find essential prime implicants of a function.

    Use the prime implicants to find the essential prime implicants of the
    function, as well as other prime implicants that are necessary to cover
    the function. Use the Petrick's method, which is a technique for
    determining all minimum sum-of-products solutions from a prime implicant
    chart.

    Args:
        numvars:    Number of function variables.
        primes:     The prime implicant minterms to minimize.
        ones:       A list of indices for the minterms for which the function
                    evaluates to 1 (true).

    Returns:
        A number representing the complexity of the result (smaller ==
        simpler), and a list of minterms comprising the minimized expression.
    """
    # Sorry, don't know what's going on here yet,
    # pylint: disable=too-many-locals,too-many-branches
    chart = []
    for one in ones:
        column = []
        for i, prime in enumerate(primes):
            if (one & (~prime[1])) == prime[0]:
                column.append(i)
        chart.append(column)

    covers = []
    if len(chart) > 0:
        covers = [set([i]) for i in chart[0]]
    for i in range(1, len(chart)):
        new_covers = []
        for cover in covers:
            for prime_index in chart[i]:
                # Sorry, don't know what this contains yet,
                # pylint: disable=invalid-name
                x = set(cover)
                x.add(prime_index)
                append = True
                for j in range(len(new_covers)-1, -1, -1):
                    if x <= new_covers[j]:
                        del new_covers[j]
                    elif x > new_covers[j]:
                        append = False
                if append:
                    new_covers.append(x)
        covers = new_covers

    min_complexity = math.inf
    for cover in covers:
        primes_in_cover = [primes[prime_index] for prime_index in cover]
        complexity = calculate_complexity(numvars, primes_in_cover)
        if complexity < min_complexity:
            min_complexity = complexity
            result = primes_in_cover

    return min_complexity, result


def calculate_complexity(numvars, minterms):
    """
    Calculate the complexity of the given function.

    The function is costed as one AND gate per minterm, all feeding a single
    OR gate, with NOT gates on negated inputs. The rules are:
        A NOT gate adds 1 to the complexity.
        An n-input AND or OR gate adds n to the complexity.
        A single-input AND or OR gate is omitted and adds nothing.

    Args:
        numvars:    Number of function variables.
        minterms:   A list of minterms that form the function.

    Returns:
        An integer representing the complexity of the function, with smaller
        values being lower complexity.
    """
    term_count = len(minterms)
    # The OR gate joining the products; a lone product needs none.
    total = 0 if term_count == 1 else term_count
    variable_mask = (1 << numvars) - 1
    for minterm in minterms:
        # Variables that appear in this product, negated or not.
        present = variable_mask & ~minterm[1]
        literal_count = _bitcount(present)
        # The AND gate for this product; a single literal needs none.
        if literal_count != 1:
            total += literal_count
        # One NOT gate per negated literal.
        total += _bitcount(present & ~minterm[0])
    return total


def minexpr_format(variables, minexpr):
    """
    Format a minimized expression as sum-of-products expression string.

    Args:
        variables:  A list of variable names. variables[j] names the variable
                    at bit j of the minterm bitmaps, with bit 0 being the
                    least significant one. Bits without a name are not
                    rendered.
        minexpr:    A "minexpr" representing the function.

    Returns:
        A string representing the function using operators '&', '|' and '!',
        or constants '0' or '1'. An empty list of minterms (an empty sum) is
        rendered as '0', and a minterm without literals (an empty product) is
        rendered as '1'.
    """
    if isinstance(minexpr, bool):
        return str(int(minexpr))
    or_terms = []
    for minterm in minexpr:
        and_terms = []
        for j, variable in enumerate(variables):
            bit = 1 << j
            if minterm[0] & bit:
                and_terms.append(variable)
            elif not minterm[1] & bit:
                and_terms.append(f'!{variable}')
        # A product with no literals is constant true; joining an empty list
        # would otherwise yield an empty string.
        or_terms.append('&'.join(and_terms) if and_terms else '1')
    if not or_terms:
        # A sum with no products is constant false.
        return '0'
    return ' | '.join(or_terms)


def _bitcount(number):
    """
    Count number of bits set in a number.

    Args:
        number: The number to count the set bits in.

    Returns:
        The number of set bits in the number.
    """
    bits = 0
    while number > 0:
        bits += number & 1
        number >>= 1
    return bits


def _is_power_of_two_or_zero(number):
    """
    Check if a number is zero or a power of two.

    Check if a number is zero or a power of two, i.e. determine if the number
    has at most 1 bit set.

    Args:
        number: The number to check.

    Returns:
        True if the number is a power of two, or zero. False otherwise.
    """
    return (number & (~number + 1)) == number


def minterm_merge(minterm_x, minterm_y):
    """
    Combine two minterms.

    Two minterms can be combined when they leave out exactly the same
    variables and their positive bitmaps differ in at most one bit. The
    differing variable, if any, is dropped from the combined product.

    Args:
        minterm_x:  The first minterm to combine.
        minterm_y:  The second minterm to combine.

    Returns:
        The combined minterm, or None if the minterms can't be combined.
    """
    absent = minterm_x[1]
    if absent != minterm_y[1]:
        return None
    positive_x = minterm_x[0]
    differing = positive_x ^ minterm_y[0]
    if _is_power_of_two_or_zero(differing):
        # Keep the shared positive literals and mark the differing variable
        # as absent from the product.
        return (positive_x & minterm_y[0], absent | differing)
    return None
